#!/usr/bin/env python3

# Minimal RISC-V interpreter, RV32IM + Zicsr only, with trace disassembly

import argparse
import sys

# Register width in bits, and the matching all-ones mask used to view a
# (possibly signed) Python int as an unsigned XLEN-bit value.
XLEN = 32
XLEN_MASK = 2 ** XLEN - 1

def extract(bits, msb, lsb):
	"""Return the unsigned bit field bits[msb:lsb] (both bounds inclusive)."""
	# Drop everything above msb first, then shift the field down to bit 0.
	keep_up_to_msb = (1 << (msb + 1)) - 1
	truncated = bits & keep_up_to_msb
	return truncated >> lsb

def sext(bits, sign_bit):
	"""Sign-extend: treat bit `sign_bit` of `bits` as the sign of a two's complement value.

	Bits above `sign_bit` are discarded; the result is a signed Python int.
	"""
	magnitude = bits & ((1 << (sign_bit + 1)) - 1)
	sign = bits & (1 << sign_bit)
	# The sign bit is already counted once (positively) in `magnitude`, so
	# subtracting twice its weight gives it a net negative weight.
	return magnitude - (sign << 1)

def concat_extract(bits, msb_lsb_pairs, signed=True):
	"""Concatenate several bit fields of `bits`, first pair most significant.

	Each (msb, lsb) pair names an inclusive bit range. When `signed` is true
	the concatenated value is sign-extended from its topmost bit.
	"""
	value = 0
	total_width = 0
	for msb, lsb in msb_lsb_pairs:
		width = msb - lsb + 1
		value = value << width | extract(bits, msb, lsb)
		total_width += width
	return sext(value, total_width - 1) if signed else value

# Note these handy functions are not used much in the main loop, because CPython is unable
# to inline them. This and similar changes result in a ~3x performance increase. :(
def imm_i(instr):
	"""Decode the sign-extended I-type immediate (instr[31:20]).

	Hand-expanded equivalent of concat_extract(instr, ((31, 20),)).
	"""
	raw = instr >> 20
	# Bit 31 of the instruction lands on bit 12 here, i.e. twice the weight
	# of the immediate's sign bit, which is what sign extension subtracts.
	sign_correction = (instr >> 19) & 0x1000
	return raw - sign_correction

def imm_s(instr):
	"""Decode the sign-extended S-type immediate (instr[31:25] ++ instr[11:7]).

	Hand-expanded equivalent of concat_extract(instr, ((31, 25), (11, 7))).
	"""
	upper = (instr >> 20) & 0xfe0   # instr[31:25] -> imm[11:5]
	lower = (instr >> 7) & 0x1f     # instr[11:7]  -> imm[4:0]
	sign_correction = (instr >> 19) & 0x1000
	return upper + lower - sign_correction

def imm_u(instr):
	"""Decode the sign-extended U-type immediate (instr[31:12] << 12).

	Hand-expanded equivalent of concat_extract(instr, ((31, 12),)) << 12.
	"""
	# The parentheses around the mask are required: binary `-` binds tighter
	# than `&`, and without them the sign correction is folded into the mask
	# and the result comes out unsigned.
	return (instr & 0xfffff000) - (instr << 1 & 0x100000000)

def imm_b(instr):
	"""Decode the sign-extended B-type (branch) immediate, as a byte offset."""
	# Instruction bit ranges, most significant first; together they form
	# offset[12:1], so the final shift restores the implicit zero LSB.
	offset_fields = ((31, 31), (7, 7), (30, 25), (11, 8))
	halfword_offset = concat_extract(instr, offset_fields)
	return halfword_offset << 1

def imm_j(instr):
	"""Decode the sign-extended J-type (jal) immediate, as a byte offset."""
	# Instruction bit ranges, most significant first; together they form
	# offset[20:1], so the final shift restores the implicit zero LSB.
	offset_fields = ((31, 31), (19, 12), (20, 20), (30, 21))
	halfword_offset = concat_extract(instr, offset_fields)
	return halfword_offset << 1


class FlatMemory:
	"""Flat byte-addressed memory, stored as a list of 32-bit words.

	`size` is in bytes; storage is `size >> 2` word entries, each held as an
	unsigned value.
	"""

	def __init__(self, size):
		self.size = size
		self.mem = [0] * (size >> 2)

	# Reads are unsigned. Writes allow signed or unsigned values and convert
	# implicitly to unsigned. Multi-byte accesses are little-endian.

	def get8(self, addr):
		"""Return the byte at `addr` as an unsigned value."""
		assert(addr >= 0 and addr < self.size)
		return self.mem[addr >> 2] >> (addr & 0x3) * 8 & 0xff

	def put8(self, addr, data):
		"""Store one byte; `data` may be anywhere in -128..255."""
		assert(addr >= 0 and addr < self.size)
		assert(data >= -1 << 7 and data < 1 << 8)
		# Clear the target byte lane, then merge in the new value.
		self.mem[addr >> 2] &= ~(0xff << 8 * (addr & 0x3))
		self.mem[addr >> 2] |= (data & 0xff) << 8 * (addr % 4)

	def get16(self, addr):
		"""Return the unsigned halfword containing `addr` (address bit 0 is ignored)."""
		return self.mem[addr >> 2] >> (addr & 0x2) * 8 & 0xffff

	def put16(self, addr, data):
		"""Store a halfword as two byte writes; `data` may be anywhere in -32768..65535."""
		assert(data >= -1 << 15 and data < 1 << 16)
		for i in range(2):
			self.put8(addr + i, data >> 8 * i & 0xff)

	def get32(self, addr):
		"""Return the unsigned word containing `addr` (address bits 1:0 are ignored)."""
		assert(addr >= 0 and addr + 3 < self.size)
		return self.mem[addr >> 2]

	def put32(self, addr, data):
		"""Store a word; `data` may be anywhere in -2**31..2**32-1."""
		assert(data >= -1 << 31 and data < 1 << 32)
		assert(addr >= 0 and addr + 3 < self.size)
		self.mem[addr >> 2] = data & 0xffff_ffff

	def loadbin(self, data, offs):
		"""Copy a binary image into memory starting at byte offset `offs`.

		`data` is either a bytes-like object (bytes/bytearray) or a file
		handle opened in "rb" mode, which is read to the end.
		"""
		if not isinstance(data, (bytes, bytearray)):
			# must be fh
			assert(data.mode == "rb")
			data = data.read()
		# `<=` so that an image ending exactly at the last byte is accepted.
		assert(offs + len(data) <= self.size)
		for i, b in enumerate(data):
			self.put8(offs + i, b)

class TBExit(Exception):
	"""Raised when the program writes to the testbench exit register; carries the value written."""

class MemWithTBIO(FlatMemory):
	"""FlatMemory plus a few write-only testbench IO registers at TB_IO_BASE.

	Only 32-bit stores (put32) are intercepted. `io_log_fmt` is a
	str.format template with one `{}` placeholder that wraps whatever the
	program prints.
	"""

	TB_IO_BASE = 0x80000000
	TB_IO_PRINT_CHAR = TB_IO_BASE + 0x0
	TB_IO_PRINT_INT = TB_IO_BASE + 0x4
	TB_IO_EXIT = TB_IO_BASE + 0x8

	def __init__(self, size, io_log_fmt="IO: {}\n"):
		super().__init__(size)
		self.io_log_fmt = io_log_fmt

	def put32(self, addr, data):
		"""Store a word, or perform testbench IO if `addr` is at or above TB_IO_BASE.

		Raises TBExit(data) on a write to TB_IO_EXIT.
		"""
		if addr < self.TB_IO_BASE:
			super().put32(addr, data)
		elif addr == self.TB_IO_PRINT_CHAR:
			# `data` may be a negative (sign-extended) register value, which
			# chr() rejects; only the low byte is used as the character.
			sys.stdout.write(self.io_log_fmt.format(chr(data & 0xff)))
		elif addr == self.TB_IO_PRINT_INT:
			# Mask to 32 bits so negative values print as 8 hex digits
			# instead of with a leading minus sign.
			sys.stdout.write(self.io_log_fmt.format(f"{data & 0xffff_ffff:08x}\n"))
		elif addr == self.TB_IO_EXIT:
			raise TBExit(data)
		else:
			print(f"Unknown IO address {addr:08x}")

class RVCSR:
	"""Tiny CSR file: mscratch plus the mcycle/minstret counters.

	Reading MTIME returns the cycle counter. Reads of any other address
	return 0; writes to any address other than MCYCLE, MINSTRET and
	MSCRATCH are dropped.
	"""

	# Write operations accepted by write()
	WRITE       = 0
	WRITE_SET   = 1
	WRITE_CLEAR = 2

	# CSR addresses
	MSCRATCH    = 0x340
	MCYCLE      = 0xb00
	MTIME       = 0xb01
	MINSTRET    = 0xb02

	def __init__(self):
		self.mscratch = 0
		self.minstret = 0
		self.mcycle = 0

	def step(self):
		"""Advance both counters by one."""
		self.minstret += 1
		self.mcycle += 1

	def read(self, addr, side_effect=True):
		"""Return the value of the CSR at `addr`, or 0 if it is not implemented.

		`side_effect` is accepted but not used by any CSR modelled here.
		"""
		if addr == RVCSR.MSCRATCH:
			return self.mscratch
		if addr == RVCSR.MINSTRET:
			return self.minstret
		if addr == RVCSR.MCYCLE or addr == RVCSR.MTIME:
			return self.mcycle
		return 0

	def write(self, addr, data, op=0):
		"""Write, set bits in, or clear bits in the CSR at `addr`, according to `op`."""
		# Set/clear are read-modify-write on the current value.
		if op == RVCSR.WRITE_SET:
			data = self.read(addr, side_effect=False) | data
		elif op == RVCSR.WRITE_CLEAR:
			data = self.read(addr, side_effect=False) & ~data

		if addr == RVCSR.MSCRATCH:
			self.mscratch = data
		elif addr == RVCSR.MINSTRET:
			self.minstret = data
		elif addr == RVCSR.MCYCLE:
			self.mcycle = data

class RVCore:
	"""RV32IM + Zicsr interpreter core with trace output and a simple stall model.

	Register values are kept as signed Python ints (sign-extended from bit
	31 on every write); `pc` is wrapped to XLEN bits on every update.
	"""

	def __init__(self, mem, reset_vector=0x40):
		self.regs = [0] * 32
		self.mem = mem
		self.pc = reset_vector
		self.csr = RVCSR()
		# Destination register number of the previous instruction if it was
		# a load (and not x0), else None. Used for load-use stall accounting.
		self.stage3_result = None

		# Single-entry record of the most recent taken backward branch.
		self.btb_valid = False
		self.btb_pc = 0

	def step(self, instr=None, log=True, cycle_accurate=True):
		"""Execute one instruction and update registers, pc, memory and CSRs.

		instr: instruction word to execute; if None it is fetched from
			self.mem.mem at the current pc.
		log: if true, print one trace line with disassembly and register/pc
			writeback.
		cycle_accurate: if true, stall cycles are added to mcycle on top of
			the one cycle counted per instruction.

		Unrecognised encodings print an "Invalid instruction" message and
		are otherwise treated as no-ops.
		"""
		if instr is None:
			instr = self.mem.mem[self.pc >> 2]
		regnum_rs1 = instr >> 15 & 0x1f
		regnum_rs2 = instr >> 20 & 0x1f
		regnum_rd  = instr >> 7 & 0x1f
		rs1 = self.regs[regnum_rs1]
		rs2 = self.regs[regnum_rs2]

		rd_wdata = None
		pc_wdata = None
		log_disasm = None
		instr_invalid = False
		stall_cycles = 0
		stage3_result_next = None

		opc = instr >> 2 & 0x1f
		funct3 = instr >> 12 & 0x7
		funct7 = instr >> 25 & 0x7f
		OPC_LOAD     = 0b00_000
		OPC_MISC_MEM = 0b00_011
		OPC_OP_IMM   = 0b00_100
		OPC_AUIPC    = 0b00_101
		OPC_STORE    = 0b01_000
		OPC_OP       = 0b01_100
		OPC_LUI      = 0b01_101
		OPC_BRANCH   = 0b11_000
		OPC_JALR     = 0b11_001
		OPC_JAL      = 0b11_011
		OPC_SYSTEM   = 0b11_100

		if opc == OPC_OP:
			if log: log_reg_str = f" x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
			if funct7 == 0b00_00000:
				if funct3 == 0b000:
					if log: log_disasm = "add" + log_reg_str
					rd_wdata = rs1 + rs2
				elif funct3 == 0b001:
					if log: log_disasm = "sll" + log_reg_str
					rd_wdata = rs1 << (rs2 & 0x1f)
				elif funct3 == 0b010:
					if log: log_disasm = "slt" + log_reg_str
					rd_wdata = rs1 < rs2
				elif funct3 == 0b011:
					if log: log_disasm = "sltu" + log_reg_str
					rd_wdata = (rs1 & XLEN_MASK) < (rs2 & XLEN_MASK)
				elif funct3 == 0b100:
					if log: log_disasm = "xor" + log_reg_str
					rd_wdata = rs1 ^ rs2
				elif funct3 == 0b101:
					if log: log_disasm = "srl" + log_reg_str
					rd_wdata = (rs1 & XLEN_MASK) >> (rs2 & 0x1f)
				elif funct3 == 0b110:
					if log: log_disasm = "or" + log_reg_str
					rd_wdata = rs1 | rs2
				elif funct3 == 0b111:
					if log: log_disasm = "and" + log_reg_str
					rd_wdata = rs1 & rs2
				else:
					instr_invalid = True
			elif funct7 == 0b01_00000:
				if funct3 == 0b000:
					if log: log_disasm = "sub" + log_reg_str
					rd_wdata = rs1 - rs2
				elif funct3 == 0b101:
					if log: log_disasm = "sra" + log_reg_str
					rd_wdata = rs1 >> (rs2 & 0x1f)
				else:
					instr_invalid = True
			elif funct7 == 0b00_00001:
				if funct3 < 0b100:
					if log:
						mul_instr_name = {0b000: "mul", 0b001: "mulh", 0b010: "mulhsu", 0b011: "mulhu"}[funct3]
						log_disasm = f"{mul_instr_name} x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
					# Operands are stored signed; switch to the unsigned view
					# where the instruction calls for it.
					mul_op_a = rs1 & XLEN_MASK if funct3 == 0b011 else rs1
					mul_op_b = rs2 & XLEN_MASK if funct3 in (0b010, 0b011) else rs2
					mul_result = mul_op_a * mul_op_b
					if funct3 != 0b000:
						mul_result >>= 32
					rd_wdata = sext(mul_result, XLEN - 1)
				else:
					if log:
						div_instr_name = {0b100: "div", 0b101: "divu", 0b110: "rem", 0b111: "remu"}[funct3]
						log_disasm = f"{div_instr_name} x{regnum_rd}, x{regnum_rs1}, x{regnum_rs2}"
					# int(a / b) truncates towards zero, unlike Python's
					# flooring // operator.
					if funct3 == 0b100:
						rd_wdata = -1 if rs2 == 0 else int(rs1 / rs2)
					elif funct3 == 0b101:
						rd_wdata = -1 if rs2 == 0 else sext((rs1 & XLEN_MASK) // (rs2 & XLEN_MASK), XLEN - 1)
					elif funct3 == 0b110:
						rd_wdata = rs1 if rs2 == 0 else rs1 - int(rs1 / rs2) * rs2
					elif funct3 == 0b111:
						rd_wdata = rs1 if rs2 == 0 else sext((rs1 & XLEN_MASK) % (rs2 & XLEN_MASK), XLEN - 1)
					else:
						instr_invalid = True
					stall_cycles = 18
			else:
				instr_invalid = True
			if not instr_invalid:
				stall_cycles += regnum_rs1 == self.stage3_result or regnum_rs2 == self.stage3_result

		elif opc == OPC_OP_IMM:
			imm = (instr >> 20) - (instr >> 19 & 0x1000) # imm_i(instr)
			if funct3 == 0b000:
				if log: log_disasm = f"addi x{regnum_rd}, x{regnum_rs1}, {imm}"
				rd_wdata = rs1 + imm
			elif funct3 == 0b010:
				if log: log_disasm = f"slti x{regnum_rd}, x{regnum_rs1}, {imm}"
				rd_wdata = 1 * (rs1 < imm)
			elif funct3 == 0b011:
				if log: log_disasm = f"sltiu x{regnum_rd}, x{regnum_rs1}, {imm & XLEN_MASK}"
				rd_wdata = 1 * (rs1 & XLEN_MASK < imm & XLEN_MASK)
			elif funct3 == 0b100:
				if log: log_disasm = f"xori x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
				rd_wdata = rs1 ^ imm
			elif funct3 == 0b110:
				if log: log_disasm = f"ori x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
				rd_wdata = rs1 | imm
			elif funct3 == 0b111:
				if log: log_disasm = f"andi x{regnum_rd}, x{regnum_rs1}, 0x{imm & XLEN_MASK:x}"
				rd_wdata = rs1 & imm
			elif funct3 == 0b001 or funct3 == 0b101:
				# shamt is regnum_rs2
				if funct7 == 0b00_00000 and funct3 == 0b001:
					if log: log_disasm = f"slli x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
					rd_wdata = rs1 << regnum_rs2
				elif funct7 == 0b00_00000 and funct3 == 0b101:
					if log: log_disasm = f"srli x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
					rd_wdata = (rs1 & XLEN_MASK) >> regnum_rs2
				elif funct7 == 0b01_00000 and funct3 == 0b101:
					if log: log_disasm = f"srai x{regnum_rd}, x{regnum_rs1}, {regnum_rs2}"
					rd_wdata = rs1 >> regnum_rs2
				else:
					instr_invalid = True
			else:
				instr_invalid = True
			if not instr_invalid:
				stall_cycles += regnum_rs1 == self.stage3_result

		elif opc == OPC_JAL:
			rd_wdata = self.pc + 4
			# pc_wdata = self.pc + imm_j(instr)
			pc_wdata = self.pc + (instr >> 20 & 0x7fe) + (instr >> 9 & 0x800) + (instr & 0xff000) - (instr >> 11 & 0x100000)
			if log: log_disasm = f"jal x{regnum_rd}, {pc_wdata & XLEN_MASK:08x}"
			stall_cycles = 1

		elif opc == OPC_JALR:
			imm = imm_i(instr)
			if log: log_disasm = f"jalr x{regnum_rd}, x{regnum_rs1}, {imm}"
			rd_wdata = self.pc + 4
			# JALR clears LSB always
			pc_wdata = (rs1 + imm) & -2
			# Same fixed one-cycle penalty as JAL, plus the load-use stall on
			# rs1. (Previously the load-use stall was computed and then
			# overwritten by the fixed penalty.)
			stall_cycles = 1 + (regnum_rs1 == self.stage3_result)

		elif opc == OPC_BRANCH:
			# target = self.pc + imm_b(instr)
			target = self.pc + (instr >> 7 & 0x1e) + (instr >> 20 & 0x7e0) + (instr << 4 & 0x800) - (instr >> 19 & 0x1000)
			taken = False
			if log: log_branch_str = f" x{regnum_rs1}, x{regnum_rs2}, {target:08x}"
			if funct3 == 0b000:
				if log: log_disasm = "beq" + log_branch_str
				taken = rs1 == rs2
			elif funct3 == 0b001:
				if log: log_disasm = "bne" + log_branch_str
				taken = rs1 != rs2
			elif funct3 == 0b100:
				if log: log_disasm = "blt" + log_branch_str
				taken = rs1 < rs2
			elif funct3 == 0b101:
				if log: log_disasm = "bge" + log_branch_str
				taken = rs1 >= rs2
			elif funct3 == 0b110:
				if log: log_disasm = "bltu" + log_branch_str
				taken = (rs1 & XLEN_MASK) < (rs2 & XLEN_MASK)
			elif funct3 == 0b111:
				if log: log_disasm = "bgeu" + log_branch_str
				taken = (rs1 & XLEN_MASK) >= (rs2 & XLEN_MASK)
			else:
				instr_invalid = True
			if not instr_invalid:
				stall_cycles += regnum_rs1 == self.stage3_result or regnum_rs2 == self.stage3_result

			# A branch is predicted taken only if it is the one recorded in
			# the single-entry BTB; any misprediction costs one cycle.
			predicted_taken = self.btb_valid and self.pc == self.btb_pc
			stall_cycles += taken != predicted_taken
			if taken:
				pc_wdata = target
				# Only taken backward branches are recorded.
				if target < self.pc:
					self.btb_valid = True
					self.btb_pc = self.pc
			elif predicted_taken:
				self.btb_valid = False


		elif opc == OPC_LOAD:
			imm = imm_i(instr)
			if log: log_load_str = f" x{regnum_rd}, {imm}(x{regnum_rs1})"
			load_addr = imm + rs1 & XLEN_MASK
			if funct3 == 0b000:
				if log: log_disasm = "lb" + log_load_str
				rd_wdata = self.mem.get8(load_addr)
				rd_wdata -= rd_wdata << 1 & 0x100
			elif funct3 == 0b001:
				if log: log_disasm = "lh" + log_load_str
				rd_wdata = self.mem.get16(load_addr)
				rd_wdata -= rd_wdata << 1 & 0x10000
			elif funct3 == 0b010:
				if log: log_disasm = "lw" + log_load_str
				rd_wdata = self.mem.get32(load_addr)
				rd_wdata -= rd_wdata << 1 & 0x100000000
			elif funct3 == 0b100:
				if log: log_disasm = "lbu" + log_load_str
				rd_wdata = self.mem.get8(load_addr)
			elif funct3 == 0b101:
				if log: log_disasm = "lhu" + log_load_str
				rd_wdata = self.mem.get16(load_addr)
			else:
				instr_invalid = True
			if not instr_invalid:
				stall_cycles += regnum_rs1 == self.stage3_result
				# Remember the load destination so the next instruction can
				# be charged a load-use stall if it reads it.
				stage3_result_next = regnum_rd

		elif opc == OPC_STORE:
			imm = imm_s(instr)
			if log: log_store_str = f" x{regnum_rs2}, {imm}(x{regnum_rs1})"
			store_addr = imm + rs1 & XLEN_MASK
			if funct3 == 0b000:
				if log: log_disasm = "sb" + log_store_str
				self.mem.put8(store_addr, rs2 & (1 << 8) - 1)
			elif funct3 == 0b001:
				if log: log_disasm = "sh" + log_store_str
				self.mem.put16(store_addr, rs2 & (1 << 16) - 1)
			elif funct3 == 0b010:
				if log: log_disasm = "sw" + log_store_str
				self.mem.put32(store_addr, rs2)
			else:
				instr_invalid = True
			if not instr_invalid:
				stall_cycles += regnum_rs1 == self.stage3_result

		elif opc == OPC_LUI:
			imm = imm_u(instr)
			if log: log_disasm = f"lui x{regnum_rd}, 0x{(imm & XLEN_MASK) >> 12:05x}"
			rd_wdata = imm

		elif opc == OPC_AUIPC:
			imm = imm_u(instr)
			if log: log_disasm = f"auipc x{regnum_rd}, 0x{(imm & XLEN_MASK) >> 12:05x}"
			rd_wdata = self.pc + imm

		elif opc == OPC_SYSTEM:
			csr_addr = extract(instr, 31, 20)
			if funct3 == 0b000 and funct7 == 0b00_00000:
				if regnum_rs2 == 0:
					if log: log_disasm = "*UNHANDLED* ecall"
				elif regnum_rs2 == 1:
					if log: log_disasm = "*UNHANDLED* ebreak"
				else:
					instr_invalid = True
			elif funct3 in (0b001, 0b010, 0b011):
				# Register-operand CSR access. The source register is rs1;
				# bits 24:20 (the rs2 field elsewhere) are part of the CSR
				# address in this instruction format.
				if log:
					instr_name = {0b001: "csrrw", 0b010: "csrrs", 0b011: "csrrc"}[funct3]
					log_disasm = f"{instr_name} x{regnum_rd}, 0x{csr_addr:x}, x{regnum_rs1}"
				csr_write_op = funct3 - 0b001
				# csrrw with rd == x0 does not read; csrrs/csrrc with
				# rs1 == x0 do not write.
				if csr_write_op != RVCSR.WRITE or regnum_rd != 0:
					rd_wdata = self.csr.read(csr_addr)
				if csr_write_op == RVCSR.WRITE or regnum_rs1 != 0:
					self.csr.write(csr_addr, rs1, op=csr_write_op)
				stall_cycles += regnum_rs1 == self.stage3_result
			elif funct3 in (0b101, 0b110, 0b111):
				# Immediate CSR access: the rs1 field holds a 5-bit
				# zero-extended immediate rather than a register number.
				csr_uimm = regnum_rs1
				if log:
					instr_name = {0b101: "csrrwi", 0b110: "csrrsi", 0b111: "csrrci"}[funct3]
					log_disasm = f"{instr_name} x{regnum_rd}, 0x{csr_addr:x}, 0x{csr_uimm:x}"
				csr_write_op = funct3 - 0b101
				if csr_write_op != RVCSR.WRITE or regnum_rd != 0:
					rd_wdata = self.csr.read(csr_addr)
				if csr_write_op == RVCSR.WRITE or csr_uimm != 0:
					self.csr.write(csr_addr, csr_uimm, op=csr_write_op)
			else:
				instr_invalid = True


		elif opc == OPC_MISC_MEM:
			# Fences have no effect in this model; they are only decoded
			# for the trace.
			if instr == 0b0000_0000_0000_00000_001_00000_0001111:
				if log: log_disasm = "fence.i"
			elif (instr & 0b1111_0000_0000_11111_111_11111_1111111) == 0b0000_0000_0000_00000_000_00000_0001111:
				if log: log_disasm = f"fence {extract(instr, 27, 24):04b}, {extract(instr, 23, 20):04b}"
			else:
				instr_invalid = True
		else:
			instr_invalid = True

		if log:
			log_str = f"{self.pc:08x}: ({instr:08x})   {log_disasm if log_disasm is not None else '':<25}"
			if rd_wdata is not None and regnum_rd != 0:
				log_str += f" : x{regnum_rd:<2} <- {rd_wdata & XLEN_MASK:08x}"
			else:
				log_str += " : " + 15 * " "
			if pc_wdata is not None:
				log_str += f" : pc  <- {pc_wdata & XLEN_MASK:08x}"
			else:
				log_str += " :"
			print(log_str)

		if rd_wdata is not None and regnum_rd != 0:
			# Wrap to 32 bits and store as a signed value.
			self.regs[regnum_rd] = (rd_wdata & 0xffffffff) - (rd_wdata << 1 & 0x100000000)

		if instr_invalid:
			# Reported before pc is advanced, so the address shown is that of
			# the offending instruction rather than the next one.
			print(f"Invalid instruction at {self.pc:08x}: {instr:08x}")

		# Keep pc within XLEN bits so it never goes negative or grows past
		# the address space.
		if pc_wdata is None:
			self.pc = (self.pc + 4) & XLEN_MASK
		else:
			self.pc = pc_wdata & XLEN_MASK

		self.csr.step()
		if cycle_accurate:
			self.csr.mcycle += stall_cycles
		# A load to x0 produces no usable result, so it cannot cause a stall.
		self.stage3_result = stage3_result_next if stage3_result_next != 0 else None



def anyint(x):
	"""Parse an integer string, taking the base from its prefix (0x, 0o, 0b or none)."""
	# Base 0 asks int() to infer the radix from the literal's prefix.
	return int(x, base=0)

def main(argv):
	"""Load a flat binary at address 0 and run it, then optionally dump memory.

	argv: command-line arguments, excluding the program name.

	Runs for at most --cycles instructions, or until the program writes to
	the testbench exit register. Each --dump START END pair prints the bytes
	in [START, END) as hex, 16 per line.
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument("binfile")
	parser.add_argument("--memsize", default = 1 << 24, type = anyint)
	parser.add_argument("--cycles", default = int(1e4), type = anyint)
	parser.add_argument("--dump", nargs=2, action="append", type=anyint)
	parser.add_argument("--quiet", "-q", action="store_true")
	args = parser.parse_args(argv)
	if args.quiet:
		mem = MemWithTBIO(args.memsize, io_log_fmt="{}")
	else:
		mem = MemWithTBIO(args.memsize)
	# Close the image file as soon as it has been copied into memory.
	with open(args.binfile, "rb") as binfile:
		mem.loadbin(binfile, 0)
	rv = RVCore(mem)
	# Counts steps started, including one that ends the run via TBExit.
	# Initialised up front so the summary below works with --cycles 0.
	cycles_run = 0
	try:
		for cycles_run in range(1, args.cycles + 1):
			rv.step(log=not args.quiet)
	except TBExit as e:
		print(f"CPU requested halt. Exit code {e}")
	except BrokenPipeError:
		# Output was piped into something that closed early (e.g. head).
		sys.exit(0)
	print(f"Ran for {cycles_run} cycles")

	for start, end in args.dump or []:
		print(f"Dumping memory from {start:08x} to {end:08x}:")
		for offset, addr in enumerate(range(start, end)):
			sep = "\n" if offset % 16 == 15 else " "
			sys.stdout.write(f"{mem.get8(addr):02x}{sep}")
		print("")

# Script entry point: pass everything after the program name to main().
if __name__ == "__main__":
	cli_args = sys.argv[1:]
	main(cli_args)
